import re
from typing import Dict, Optional

from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import (
    get_model,
    prepare_model,
    update_sampling_params,
)
from data_juicer.utils.ray_utils import is_ray_mode

# Heavy optional backends are wrapped in LazyLoader rather than imported
# eagerly; `vllm` is only dereferenced on the vLLM code path below.
torch, vllm = LazyLoader("torch"), LazyLoader("vllm")

OP_NAME = "optimize_qa_mapper"


@OPERATORS.register_module(OP_NAME)
class OptimizeQAMapper(Mapper):
    """Mapper that asks a language model to rewrite question-answer pairs.

    For every sample, the current question (``self.query_key``) and answer
    (``self.response_key``) are rendered with ``qa_pair_template``, wrapped
    in ``input_template`` and sent to the model after ``system_prompt``.
    The model reply is parsed with the ``output_pattern`` regular
    expression; each non-empty question/answer it yields overwrites the
    corresponding field. The backend is a vLLM engine (``enable_vllm``), a
    Hugging Face pipeline (``is_hf_model``) or an API model otherwise. The
    system prompt, templates and output pattern can all be overridden. If
    VLLM is enabled, the operator accelerates inference on CUDA devices."""

    # Adjacent literals inside parentheses keep source indentation out of
    # the prompt text.
    DEFAULT_SYSTEM_PROMPT = (
        "请优化输入的问答对，使【问题】和【回答】都更加详细、准确。"
        "必须按照以下标记格式，直接输出优化后的问答对：\n"
        "【问题】\n"
        "优化后的问题\n"
        "【回答】\n"
        "优化后的回答"
    )
    DEFAULT_INPUT_TEMPLATE = "以下是原始问答对：\n{}"
    DEFAULT_QA_PAIR_TEMPLATE = "【问题】\n{}\n【回答】\n{}"
    DEFAULT_OUTPUT_PATTERN = r".*?【问题】\s*(.*?)\s*【回答】\s*(.*)"

    _accelerator = "cuda"

    def __init__(
        self,
        api_or_hf_model: str = "Qwen/Qwen2.5-7B-Instruct",
        is_hf_model: bool = True,
        *,
        api_endpoint: Optional[str] = None,
        response_path: Optional[str] = None,
        system_prompt: Optional[str] = None,
        input_template: Optional[str] = None,
        qa_pair_template: Optional[str] = None,
        output_pattern: Optional[str] = None,
        try_num: PositiveInt = 3,
        enable_vllm: bool = False,
        model_params: Optional[Dict] = None,
        sampling_params: Optional[Dict] = None,
        **kwargs,
    ):
        """
        Initialization method.

        :param api_or_hf_model: Name of the API model or of the Hugging Face
            model to load.
        :param is_hf_model: When true, load a Hugging Face model; when false,
            call the model through an API. Has no effect when `enable_vllm`
            is set.
        :param api_endpoint: URL endpoint used by the API backend.
        :param response_path: Path used to pull the content out of the API
            response. Defaults to 'choices.0.message.content'.
        :param system_prompt: System prompt that steers the optimization.
        :param input_template: Template for the user message sent to the
            model. It must contain exactly one '{}' placeholder, which is
            filled with the pair produced by `qa_pair_template`.
        :param qa_pair_template: Template that renders the question and
            answer. It must contain two '{}' placeholders, filled with the
            question and then the answer.
        :param output_pattern: Regular expression whose first two groups
            capture the optimized question and answer from the reply.
        :param try_num: Maximum number of generation attempts. Another
            attempt is made after an exception, or when parsing yields
            neither a question nor an answer.
        :param enable_vllm: Use vLLM to accelerate inference.
        :param model_params: Extra keyword arguments for model creation.
        :param sampling_params: Sampling parameters for generation (e.g.
            {'temperature': 0.9, 'top_p': 0.95}).
        :param kwargs: Extra keyword arguments forwarded to the base class.
        """
        super().__init__(**kwargs)

        # Any unset (or empty) prompt/template falls back to the class default.
        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
        self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
        self.qa_pair_template = qa_pair_template or self.DEFAULT_QA_PAIR_TEMPLATE
        self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN

        self.try_num = try_num

        self.is_hf_model = is_hf_model
        self.enable_vllm = enable_vllm
        init_kwargs = model_params or {}
        gen_kwargs = update_sampling_params(sampling_params or {}, api_or_hf_model, self.enable_vllm)

        if enable_vllm:
            if not is_ray_mode():
                # cannot initialize vllm replicas on different GPUs
                self.num_proc = 1
            self.model_key = prepare_model(
                model_type="vllm", pretrained_model_name_or_path=api_or_hf_model, **init_kwargs
            )
            # vLLM wants a SamplingParams object rather than a plain dict.
            self.sampling_params = vllm.SamplingParams(**gen_kwargs)
            return

        # The Hugging Face and API backends both take the plain dict.
        self.sampling_params = gen_kwargs
        if is_hf_model:
            self.model_key = prepare_model(
                model_type="huggingface",
                pretrained_model_name_or_path=api_or_hf_model,
                return_pipe=True,
                **init_kwargs,
            )
        else:
            self.model_key = prepare_model(
                model_type="api",
                model=api_or_hf_model,
                endpoint=api_endpoint,
                response_path=response_path,
                **init_kwargs,
            )

    def build_input(self, sample):
        """Render the sample's question/answer into the user prompt string."""
        question = sample[self.query_key]
        answer = sample[self.response_key]
        pair_text = self.qa_pair_template.format(question, answer)
        return self.input_template.format(pair_text)

    def parse_output(self, raw_output):
        """Extract ``(question, answer)`` from a model reply.

        The pattern is matched from the start of the reply with ``re.DOTALL``.
        On a match, the stripped first and second groups are returned;
        otherwise ``(None, None)``.
        """
        logger.debug(raw_output)
        found = re.match(self.output_pattern, raw_output, re.DOTALL)
        if not found:
            return None, None
        question = found.group(1).strip()
        answer = found.group(2).strip()
        return question, answer

    def process_single(self, sample, rank=None):
        """Optimize one sample's question/answer pair in place and return it.

        Generation is retried up to ``self.try_num`` times. Exceptions are
        logged as warnings rather than raised. If every attempt fails or
        parses to nothing, the sample is returned unchanged.
        """
        loaded = get_model(self.model_key, rank, self.use_cuda())
        if self.enable_vllm or self.is_hf_model:
            # Local backends come back as a pair; only the first item is used.
            model, _ = loaded
        else:
            model = loaded

        user_prompt = self.build_input(sample)
        messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_prompt}]

        def generate():
            # Run one generation with whichever backend was configured.
            if self.enable_vllm:
                outputs = model.chat(messages, self.sampling_params)
                return outputs[0].outputs[0].text
            if self.is_hf_model:
                # model is a pipeline; only the newly generated text is wanted
                outputs = model(messages, return_full_text=False, **self.sampling_params)
                return outputs[0]["generated_text"]
            return model(messages, **self.sampling_params)

        new_q, new_a = None, None
        for _attempt in range(self.try_num):
            try:
                new_q, new_a = self.parse_output(generate())
                if new_q or new_a:
                    break
            except Exception as e:
                logger.warning(f"Exception: {e}")

        # Only overwrite fields for which the model produced non-empty text.
        for field_key, new_value in ((self.query_key, new_q), (self.response_key, new_a)):
            if new_value:
                sample[field_key] = new_value

        return sample
